{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Transposed convolution arithmetic\n",
    "\n",
    "A guide to convolution arithmetic for deep learning\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Training DCGANs with Keras and TensorFlow\n",
    "\n",
    "A DCGANs implementation using the transposed convolution technique and the [Keras](https://keras.io/) library.\n",
    "\n",
    "\n",
    "### 1. Load data\n",
    "\n",
    "#### Load libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-19T00:02:47.400174Z",
     "start_time": "2018-06-19T00:02:46.597063Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-19T00:02:49.021637Z",
     "start_time": "2018-06-19T00:02:47.402154Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.datasets import mnist\n",
    "from keras.models import Sequential, Model\n",
    "from keras.layers import Input, Dense, LeakyReLU, BatchNormalization\n",
    "from keras.layers import Conv2D, Conv2DTranspose, Reshape, Flatten\n",
    "from keras.optimizers import Adam\n",
    "from keras import initializers\n",
    "from keras.utils import plot_model\n",
    "from keras import backend as K"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Getting the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-19T00:02:49.602716Z",
     "start_time": "2018-06-19T00:02:49.023733Z"
    }
   },
   "outputs": [],
   "source": [
    "# load dataset\n",
    "(X_train, y_train), (X_test, y_test) = mnist.load_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Reshaping and normalizing the inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-19T00:02:50.646294Z",
     "start_time": "2018-06-19T00:02:50.193290Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X_train.shape (60000, 28, 28)\n",
      "X_train reshape: (60000, 28, 28, 1)\n"
     ]
    }
   ],
   "source": [
    "print('X_train.shape', X_train.shape)\n",
    "\n",
    "if K.image_data_format() == 'channels_first':\n",
    "    X_train = X_train.reshape(X_train.shape[0], 1, 28, 28)\n",
    "    X_test = X_test.reshape(X_test.shape[0], 1, 28, 28)\n",
    "    input_shape = (1, 28, 28)\n",
    "else:\n",
    "    X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)\n",
    "    X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)\n",
    "    input_shape = (28, 28, 1)\n",
    "\n",
    "# the generator is using tanh activation, for which we need to preprocess \n",
    "# the image data into the range between -1 and 1.\n",
    "\n",
    "X_train = np.float32(X_train)\n",
    "X_train = (X_train / 255 - 0.5) * 2\n",
    "X_train = np.clip(X_train, -1, 1)\n",
    "\n",
    "print('X_train reshape:', X_train.shape)"
   ]
  },
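  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on the rescaling above, the mapping can be written out explicitly. This is only a sketch; the symbols $x$ (raw pixel value) and $x'$ (rescaled value) are notation introduced here, not names from the code.\n",
    "\n",
    "$$x' = 2\\left(\\frac{x}{255} - 0.5\\right)$$\n",
    "\n",
    "The endpoints land where tanh expects them: $x = 0 \\mapsto x' = -1$, $x = 255 \\mapsto x' = 1$, and the midpoint $x = 127.5 \\mapsto x' = 0$. Because the mapping already stays inside $[-1, 1]$, the `np.clip` call acts only as a safeguard."
   ]
  },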
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Define model\n",
    "\n",
    "#### Generator\n",
    "\n",
    "Our generator using the **inverse of convolution**, called transposed convolution. \n",
    "\n",
    "In between layers, BatchNormalization stabilizes learning. \n",
    "\n",
    "The activation function after each layer is a LeakyReLU. \n",
    "\n",
    "The output of the tanh at the last layer produces the fake image. \n",
    "\n",
    "![generator model](../img/generative.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-19T00:02:51.514030Z",
     "start_time": "2018-06-19T00:02:50.648356Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "dense_43 (Dense)             (None, 12544)             1266944   \n",
      "_________________________________________________________________\n",
      "reshape_43 (Reshape)         (None, 7, 7, 256)         0         \n",
      "_________________________________________________________________\n",
      "batch_normalization_89 (Batc (None, 7, 7, 256)         1024      \n",
      "_________________________________________________________________\n",
      "leaky_re_lu_92 (LeakyReLU)   (None, 7, 7, 256)         0         \n",
      "_________________________________________________________________\n",
      "conv2d_transpose_51 (Conv2DT (None, 14, 14, 128)       131200    \n",
      "_________________________________________________________________\n",
      "batch_normalization_90 (Batc (None, 14, 14, 128)       512       \n",
      "_________________________________________________________________\n",
      "leaky_re_lu_93 (LeakyReLU)   (None, 14, 14, 128)       0         \n",
      "=================================================================\n",
      "Total params: 1,399,680\n",
      "Trainable params: 1,398,912\n",
      "Non-trainable params: 768\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# latent space dimension\n",
    "latent_dim = 100\n",
    "\n",
    "# imagem dimension 28x28\n",
    "img_dim = 784\n",
    "\n",
    "init = initializers.RandomNormal(stddev=0.02)\n",
    "\n",
    "# Generator network\n",
    "generator = Sequential()\n",
    "\n",
    "# FC: 7x7x256\n",
    "generator.add(Dense(7*7*256, input_shape=(latent_dim,), kernel_initializer=init))\n",
    "generator.add(Reshape((7, 7, 256)))\n",
    "generator.add(BatchNormalization())\n",
    "generator.add(LeakyReLU(0.2))\n",
    "\n",
    "# Conv 1: 14x14x128\n",
    "generator.add(Conv2DTranspose(128, kernel_size=2, strides=2, padding='valid'))\n",
    "generator.add(BatchNormalization())\n",
    "generator.add(LeakyReLU(0.2))\n",
    "\n",
    "# # Conv 2: 28x28x64\n",
    "# generator.add(Conv2DTranspose(64, kernel_size=3, strides=2, padding='same'))\n",
    "# generator.add(BatchNormalization())\n",
    "# generator.add(LeakyReLU(0.2))\n",
    "\n",
    "# # Conv 3: 28x28x32\n",
    "# generator.add(Conv2DTranspose(32, kernel_size=3, strides=1, padding='same'))\n",
    "# generator.add(BatchNormalization())\n",
    "# generator.add(LeakyReLU(0.2))\n",
    "\n",
    "# # Conv 4: 28x28x1\n",
    "# generator.add(Conv2DTranspose(1, kernel_size=3, strides=2, padding='same',\n",
    "#                               activation='tanh'))\n",
    "\n",
    "generator.summary()"
   ]
  },
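  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shapes and parameter counts in the summary above can be checked by hand. The sketch below assumes the usual transposed-convolution arithmetic for a kernel size no smaller than the stride; the symbols $i$ (input size), $o$ (output size), $k$ (kernel size), and $s$ (stride) are notation introduced here.\n",
    "\n",
    "With `padding='valid'`, the first transposed convolution gives\n",
    "\n",
    "$$o = s(i - 1) + k = 2(7 - 1) + 2 = 14$$\n",
    "\n",
    "With `padding='same'`, Keras produces $o = s \\cdot i$, so a stride-2 layer would take 14 to 28 and a stride-1 layer would keep 28.\n",
    "\n",
    "The parameter counts follow the same kind of bookkeeping:\n",
    "\n",
    "* Dense: $100 \\cdot 12544 + 12544 = 1266944$\n",
    "* Conv2DTranspose: $2 \\cdot 2 \\cdot 256 \\cdot 128 + 128 = 131200$\n",
    "* BatchNormalization: $4 \\cdot 256 = 1024$ and $4 \\cdot 128 = 512$ (scale, shift, moving mean, moving variance per channel)\n",
    "\n",
    "The moving statistics are not trained, which accounts for the $2 \\cdot 256 + 2 \\cdot 128 = 768$ non-trainable parameters."
   ]
  },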
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Generator model visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-19T00:02:51.521078Z",
     "start_time": "2018-06-19T00:02:51.516058Z"
    }
   },
   "outputs": [],
   "source": [
    "# prints a summary representation of your model\n",
    "# generator.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Discriminator\n",
    "\n",
    "Our discriminator is a **convolutional neural network** that takes a 28x28 image with 1 channel. The values in the image is expected to be between -1 and 1.\n",
    "\n",
    "It takes a digit image and classifies whether an image is real (1) or not (0).\n",
    "\n",
    "The last activation is sigmoid to tell us the probability of whether the input image is real or not.\n",
    "\n",
    "![discriminator model](../img/discriminative.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-19T00:02:51.602831Z",
     "start_time": "2018-06-19T00:02:51.522937Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "conv2d_7 (Conv2D)            (None, 14, 14, 32)        320       \n",
      "_________________________________________________________________\n",
      "leaky_re_lu_97 (LeakyReLU)   (None, 14, 14, 32)        0         \n",
      "=================================================================\n",
      "Total params: 320\n",
      "Trainable params: 320\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Discriminator network\n",
    "discriminator = Sequential()\n",
    "\n",
    "# Conv 1: 14x14x64\n",
    "discriminator.add(Conv2D(32, kernel_size=3, strides=2, padding='same',\n",
    "                         input_shape=(28, 28, 1), kernel_initializer=init))\n",
    "discriminator.add(LeakyReLU(0.2))\n",
    "\n",
    "# Conv 2:\n",
    "# discriminator.add(Conv2D(64, kernel_size=3, strides=2, padding='same'))\n",
    "# discriminator.add(LeakyReLU(0.2))\n",
    "\n",
    "# # Conv 3: \n",
    "# discriminator.add(Conv2D(128, kernel_size=3, strides=2, padding='same'))\n",
    "# discriminator.add(LeakyReLU(0.2))\n",
    "\n",
    "# # Conv 3: \n",
    "# discriminator.add(Conv2D(512, kernel_size=3, strides=1, padding='same'))\n",
    "# discriminator.add(LeakyReLU(0.2))\n",
    "\n",
    "# # FC\n",
    "# discriminator.add(Flatten())\n",
    "# discriminator.add(LeakyReLU(0.2))\n",
    "\n",
    "# # Output\n",
    "# discriminator.add(Dense(1, activation='sigmoid'))\n",
    "\n",
    "discriminator.summary()"
   ]
  },
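  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The discriminator summary above can be verified the same way. This is a sketch under the assumption that a strided convolution with `padding='same'` rounds the output size up; $i$, $o$, and $s$ (input size, output size, stride) are notation introduced here.\n",
    "\n",
    "$$o = \\left\\lceil \\frac{i}{s} \\right\\rceil = \\left\\lceil \\frac{28}{2} \\right\\rceil = 14$$\n",
    "\n",
    "The layer has 32 filters of size $3 \\times 3$ over 1 input channel, plus one bias per filter:\n",
    "\n",
    "$$3 \\cdot 3 \\cdot 1 \\cdot 32 + 32 = 320$$\n",
    "\n",
    "By the same rule, further stride-2 layers would reduce the feature map from 14 to 7 and then from 7 to 4."
   ]
  },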
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Discriminator model visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-19T00:02:51.608924Z",
     "start_time": "2018-06-19T00:02:51.604651Z"
    }
   },
   "outputs": [],
   "source": [
    "# prints a summary representation of your model\n",
    "discriminator.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Compile model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Compile discriminator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-19T00:02:51.697630Z",
     "start_time": "2018-06-19T00:02:51.611026Z"
    }
   },
   "outputs": [],
   "source": [
    "# Optimizer\n",
    "opt = Adam(lr=0.0002, beta_1=0.5)\n",
    "\n",
    "discriminator.compile(opt, loss='binary_crossentropy',\n",
    "                      metrics=['binary_accuracy'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Combined network\n",
    "\n",
    "We connect the generator and the discriminator to make a DCGAN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-19T00:02:52.080923Z",
     "start_time": "2018-06-19T00:02:51.699590Z"
    }
   },
   "outputs": [],
   "source": [
    "# d_g = discriminador(generador(z))\n",
    "discriminator.trainable = False\n",
    "\n",
    "z = Input(shape=(latent_dim,))\n",
    "img = generator(z)\n",
    "decision = discriminator(img)\n",
    "d_g = Model(inputs=z, outputs=decision)\n",
    "\n",
    "d_g.compile(opt, loss='binary_crossentropy',\n",
    "            metrics=['binary_accuracy'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### GAN model vizualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-19T00:02:52.087799Z",
     "start_time": "2018-06-19T00:02:52.083174Z"
    }
   },
   "outputs": [],
   "source": [
    "# prints a summary representation of your model\n",
    "d_g.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. Fit model\n",
    "\n",
    "We train the discriminator and the generator in turn in a loop as follows:\n",
    "\n",
    "1. Set the discriminator trainable\n",
    "2. Train the discriminator with the real digit images and the images generated by the generator to classify the real and fake images.\n",
    "3. Set the discriminator non-trainable\n",
    "4. Train the generator as part of the GAN. We feed latent samples into the GAN and let the generator to produce digit images and use the discriminator to classify the image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-19T02:02:33.125417Z",
     "start_time": "2018-06-19T00:02:52.089965Z"
    }
   },
   "outputs": [],
   "source": [
    "epochs = 100\n",
    "batch_size = 64\n",
    "smooth = 0.1\n",
    "\n",
    "real = np.ones(shape=(batch_size, 1))\n",
    "fake = np.zeros(shape=(batch_size, 1))\n",
    "\n",
    "d_loss = []\n",
    "g_loss = []\n",
    "\n",
    "for e in range(epochs + 1):\n",
    "    for i in range(len(X_train) // batch_size):\n",
    "        \n",
    "        # Train Discriminator weights\n",
    "        discriminator.trainable = True\n",
    "        \n",
    "        # Real samples\n",
    "        X_batch = X_train[i*batch_size:(i+1)*batch_size]\n",
    "        d_loss_real = discriminator.train_on_batch(x=X_batch,\n",
    "                                                   y=real * (1 - smooth))\n",
    "        \n",
    "        # Fake Samples\n",
    "        z = np.random.normal(loc=0, scale=1, size=(batch_size, latent_dim))\n",
    "        X_fake = generator.predict_on_batch(z)\n",
    "        d_loss_fake = discriminator.train_on_batch(x=X_fake, y=fake)\n",
    "         \n",
    "        # Discriminator loss\n",
    "        d_loss_batch = 0.5 * (d_loss_real[0] + d_loss_fake[0])\n",
    "        \n",
    "        # Train Generator weights\n",
    "        discriminator.trainable = False\n",
    "        g_loss_batch = d_g.train_on_batch(x=z, y=real)\n",
    "\n",
    "        print(\n",
    "            'epoch = %d/%d, batch = %d/%d, d_loss=%.3f, g_loss=%.3f' % (e + 1, epochs, i, len(X_train) // batch_size, d_loss_batch, g_loss_batch[0]),\n",
    "            100*' ',\n",
    "            end='\\r'\n",
    "        )\n",
    "    \n",
    "    d_loss.append(d_loss_batch)\n",
    "    g_loss.append(g_loss_batch[0])\n",
    "    print('epoch = %d/%d, d_loss=%.3f, g_loss=%.3f' % (e + 1, epochs, d_loss[-1], g_loss[-1]), 100*' ')\n",
    "\n",
    "    if e % 10 == 0:\n",
    "        samples = 10\n",
    "        x_fake = generator.predict(np.random.normal(loc=0, scale=1, size=(samples, latent_dim)))\n",
    "\n",
    "        for k in range(samples):\n",
    "            plt.subplot(2, 5, k+1)\n",
    "            plt.imshow(x_fake[k].reshape(28, 28), cmap='gray')\n",
    "            plt.xticks([])\n",
    "            plt.yticks([])\n",
    "\n",
    "        plt.tight_layout()\n",
    "        plt.show()"
   ]
  },
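  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the quantities tracked in the loop above can be written as binary cross-entropy terms. This is a sketch rather than a derivation; $D$ and $G$ stand for the discriminator and generator, $\\epsilon$ for the `smooth` value of 0.1, and each term is understood as an average over the batch. All symbols are notation introduced here.\n",
    "\n",
    "Real batch, with the smoothed target $1 - \\epsilon = 0.9$:\n",
    "\n",
    "$$L_D^{\\mathrm{real}} = -\\left[(1 - \\epsilon)\\log D(x) + \\epsilon \\log\\left(1 - D(x)\\right)\\right]$$\n",
    "\n",
    "Fake batch, with target 0:\n",
    "\n",
    "$$L_D^{\\mathrm{fake}} = -\\log\\left(1 - D(G(z))\\right)$$\n",
    "\n",
    "The reported discriminator loss is their mean:\n",
    "\n",
    "$$L_D = \\frac{1}{2}\\left(L_D^{\\mathrm{real}} + L_D^{\\mathrm{fake}}\\right)$$\n",
    "\n",
    "The generator is trained through the combined model with target 1 and no smoothing:\n",
    "\n",
    "$$L_G = -\\log D(G(z))$$"
   ]
  },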
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5. Evaluate model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-19T02:02:33.284689Z",
     "start_time": "2018-06-19T02:02:33.127648Z"
    }
   },
   "outputs": [],
   "source": [
    "# plotting the metrics\n",
    "plt.plot(d_loss)\n",
    "plt.plot(g_loss)\n",
    "plt.title('Model loss')\n",
    "plt.ylabel('Loss')\n",
    "plt.xlabel('Epoch')\n",
    "plt.legend(['Discriminator', 'Adversarial'], loc='center right')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## References\n",
    "\n",
    "* [Generative Adversarial Networks or GANs](https://arxiv.org/abs/1406.2661)\n",
    "* [How to Train a GAN? Tips and tricks to make GANs work](https://github.com/soumith/ganhacks)\n",
    "* [THE MNIST DATABASE of handwritten digits](http://yann.lecun.com/exdb/mnist/)\n",
    "* [Convolution](https://devblogs.nvidia.com/deep-learning-nutshell-core-concepts/)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
